Load libraries

In [1]:
import os
import cv2
import glob
import numpy as np
from keras.models import *
from keras.layers import *
from keras.applications import *
from keras.preprocessing.image import *
Using TensorFlow backend.

Load the dataset

In [2]:
basedir = "/ext/Data/distracted_driver_detection/"

model_image_size = 224

print("-------- loading train data")
X_train = list()
y_train = list()
for i in range(10):
    dir = os.path.join(basedir, "train", "c%d"%i)
    image_files = glob.glob(os.path.join(dir,"*.jpg"))
    print("loding {}, image count={}".format(dir, len(image_files)))
    for image_file in image_files:
        image = cv2.imread(image_file)
        X_train.append(cv2.resize(image, (model_image_size, model_image_size)))
        label = np.zeros(10, dtype=np.uint8)
        label[i]=1
        y_train.append(label)
X_train = np.array(X_train)
y_train = np.array(y_train)
        
print("-------- loading valid data")
X_valid = list()
y_valid = list()
for i in range(10):
    dir = os.path.join(basedir, "valid", "c%d"%i)
    image_files = glob.glob(os.path.join(dir,"*.jpg"))
    print("loding {}, image count={}".format(dir, len(image_files)))
    for image_file in image_files:
        image = cv2.imread(image_file)
        X_valid.append(cv2.resize(image, (model_image_size, model_image_size)))
        label = np.zeros(10, dtype=np.uint8)
        label[i]=1
        y_valid.append(label)
X_valid = np.array(X_valid)
y_valid = np.array(y_valid)
-------- loading train data
loading /ext/Data/distracted_driver_detection/train/c0, image count=2308
loading /ext/Data/distracted_driver_detection/train/c1, image count=2096
loading /ext/Data/distracted_driver_detection/train/c2, image count=2136
loading /ext/Data/distracted_driver_detection/train/c3, image count=2185
loading /ext/Data/distracted_driver_detection/train/c4, image count=2160
loading /ext/Data/distracted_driver_detection/train/c5, image count=2152
loading /ext/Data/distracted_driver_detection/train/c6, image count=2164
loading /ext/Data/distracted_driver_detection/train/c7, image count=1843
loading /ext/Data/distracted_driver_detection/train/c8, image count=1771
loading /ext/Data/distracted_driver_detection/train/c9, image count=1972
-------- loading valid data
loading /ext/Data/distracted_driver_detection/valid/c0, image count=181
loading /ext/Data/distracted_driver_detection/valid/c1, image count=171
loading /ext/Data/distracted_driver_detection/valid/c2, image count=181
loading /ext/Data/distracted_driver_detection/valid/c3, image count=161
loading /ext/Data/distracted_driver_detection/valid/c4, image count=166
loading /ext/Data/distracted_driver_detection/valid/c5, image count=160
loading /ext/Data/distracted_driver_detection/valid/c6, image count=161
loading /ext/Data/distracted_driver_detection/valid/c7, image count=159
loading /ext/Data/distracted_driver_detection/valid/c8, image count=140
loading /ext/Data/distracted_driver_detection/valid/c9, image count=157

Check the training and validation splits

In [5]:
print(X_train.shape)
print(y_train.shape)
print(X_valid.shape)
print(y_valid.shape)
(20787, 224, 224, 3)
(20787, 10)
(1637, 224, 224, 3)
(1637, 10)
In [6]:
base_model = ResNet50(input_tensor=Input((model_image_size, model_image_size, 3)), weights='imagenet', include_top=False)

# Freeze the ImageNet-pretrained backbone; only the new head is trained.
for layer in base_model.layers:
    layer.trainable = False

x = GlobalAveragePooling2D()(base_model.output)
x = Dropout(0.25)(x)
x = Dense(10, activation='softmax')(x)
model = Model(base_model.input, x)
# Note: with one-hot 10-class labels and a softmax output, 'categorical_crossentropy'
# is the conventional loss; 'binary_crossentropy' (used for the runs logged below)
# scores each of the 10 outputs independently, which inflates the reported accuracy.
model.compile(optimizer='adadelta', loss='binary_crossentropy', metrics=['accuracy'])

# model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
print("done")
done

Train the model

In [7]:
model.fit(X_train, y_train, batch_size=16, epochs=10, validation_data=(X_valid, y_valid))

model.save("models/resnet50-mymodel.h5")
Train on 20787 samples, validate on 1637 samples
Epoch 1/10
20787/20787 [==============================] - 96s - loss: 0.1716 - acc: 0.9375 - val_loss: 0.2335 - val_acc: 0.9104
Epoch 2/10
20787/20787 [==============================] - 94s - loss: 0.0831 - acc: 0.9735 - val_loss: 0.1959 - val_acc: 0.9249
Epoch 3/10
20787/20787 [==============================] - 94s - loss: 0.0617 - acc: 0.9814 - val_loss: 0.1772 - val_acc: 0.9311
Epoch 4/10
20787/20787 [==============================] - 91s - loss: 0.0516 - acc: 0.9845 - val_loss: 0.1970 - val_acc: 0.9234
Epoch 5/10
20787/20787 [==============================] - 90s - loss: 0.0454 - acc: 0.9859 - val_loss: 0.1926 - val_acc: 0.9258
Epoch 6/10
20787/20787 [==============================] - 90s - loss: 0.0409 - acc: 0.9877 - val_loss: 0.1868 - val_acc: 0.9271
Epoch 7/10
20787/20787 [==============================] - 90s - loss: 0.0380 - acc: 0.9884 - val_loss: 0.1920 - val_acc: 0.9287
Epoch 8/10
20787/20787 [==============================] - 90s - loss: 0.0355 - acc: 0.9891 - val_loss: 0.1916 - val_acc: 0.9293
Epoch 9/10
20787/20787 [==============================] - 90s - loss: 0.0342 - acc: 0.9894 - val_loss: 0.1978 - val_acc: 0.9282
Epoch 10/10
20787/20787 [==============================] - 90s - loss: 0.0326 - acc: 0.9899 - val_loss: 0.1999 - val_acc: 0.9258
In [8]:
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.models import *

model = load_model("models/resnet50-mymodel.h5")
print("load successed")

SVG(model_to_dot(model).create(prog='dot', format='svg'))
load succeeded
Out[8]:
[SVG model graph: the ResNet50 backbone (input_1 → conv1 → bn_conv1 → max_pooling2d_1 → res2a…res5c residual blocks → activation_49 → avg_pool) followed by the new head: global_average_pooling2d_1 → dropout_1 → dense_1]

CAM visualization

http://cnnlocalization.csail.mit.edu/

$cam = (P - 0.5) \cdot output \cdot w$

  • cam: the class activation map
  • P: the predicted probability of the class
  • output: the output of the final convolutional layer (7×7×2048 per image, see the shapes printed in In [12] below)
  • w: the weights of the final dense layer (2048×10)
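
A minimal NumPy sketch of this formula, mirroring the `show_heatmap_image` cell below. The shapes match those printed by In [12]; the random `out`, `w`, `P`, and `class_idx` values here are placeholders, not real model outputs:

import numpy as np

out = np.random.rand(7, 7, 2048).astype(np.float32)  # placeholder conv feature map for one image
w = np.random.rand(2048, 10).astype(np.float32)      # placeholder dense weights: 2048 features -> 10 classes
P, class_idx = 0.97, 3                               # placeholder probability and predicted class

cam = (P - 0.5) * np.matmul(out, w)                  # (7, 7, 10): per-position class scores
cam = cam[:, :, class_idx]                           # (7, 7): keep the map for the predicted class
cam = (cam - cam.min()) / (cam.max() - cam.min())    # normalize to [0, 1] before upsampling
print(cam.shape)                                     # (7, 7)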
In [10]:
for idx, layer in enumerate(model.layers):
    print("{} - {}".format(layer.name, idx))
input_1 - 0
conv1 - 1
bn_conv1 - 2
activation_1 - 3
max_pooling2d_1 - 4
res2a_branch2a - 5
bn2a_branch2a - 6
activation_2 - 7
res2a_branch2b - 8
bn2a_branch2b - 9
activation_3 - 10
res2a_branch2c - 11
res2a_branch1 - 12
bn2a_branch2c - 13
bn2a_branch1 - 14
add_1 - 15
activation_4 - 16
res2b_branch2a - 17
bn2b_branch2a - 18
activation_5 - 19
res2b_branch2b - 20
bn2b_branch2b - 21
activation_6 - 22
res2b_branch2c - 23
bn2b_branch2c - 24
add_2 - 25
activation_7 - 26
res2c_branch2a - 27
bn2c_branch2a - 28
activation_8 - 29
res2c_branch2b - 30
bn2c_branch2b - 31
activation_9 - 32
res2c_branch2c - 33
bn2c_branch2c - 34
add_3 - 35
activation_10 - 36
res3a_branch2a - 37
bn3a_branch2a - 38
activation_11 - 39
res3a_branch2b - 40
bn3a_branch2b - 41
activation_12 - 42
res3a_branch2c - 43
res3a_branch1 - 44
bn3a_branch2c - 45
bn3a_branch1 - 46
add_4 - 47
activation_13 - 48
res3b_branch2a - 49
bn3b_branch2a - 50
activation_14 - 51
res3b_branch2b - 52
bn3b_branch2b - 53
activation_15 - 54
res3b_branch2c - 55
bn3b_branch2c - 56
add_5 - 57
activation_16 - 58
res3c_branch2a - 59
bn3c_branch2a - 60
activation_17 - 61
res3c_branch2b - 62
bn3c_branch2b - 63
activation_18 - 64
res3c_branch2c - 65
bn3c_branch2c - 66
add_6 - 67
activation_19 - 68
res3d_branch2a - 69
bn3d_branch2a - 70
activation_20 - 71
res3d_branch2b - 72
bn3d_branch2b - 73
activation_21 - 74
res3d_branch2c - 75
bn3d_branch2c - 76
add_7 - 77
activation_22 - 78
res4a_branch2a - 79
bn4a_branch2a - 80
activation_23 - 81
res4a_branch2b - 82
bn4a_branch2b - 83
activation_24 - 84
res4a_branch2c - 85
res4a_branch1 - 86
bn4a_branch2c - 87
bn4a_branch1 - 88
add_8 - 89
activation_25 - 90
res4b_branch2a - 91
bn4b_branch2a - 92
activation_26 - 93
res4b_branch2b - 94
bn4b_branch2b - 95
activation_27 - 96
res4b_branch2c - 97
bn4b_branch2c - 98
add_9 - 99
activation_28 - 100
res4c_branch2a - 101
bn4c_branch2a - 102
activation_29 - 103
res4c_branch2b - 104
bn4c_branch2b - 105
activation_30 - 106
res4c_branch2c - 107
bn4c_branch2c - 108
add_10 - 109
activation_31 - 110
res4d_branch2a - 111
bn4d_branch2a - 112
activation_32 - 113
res4d_branch2b - 114
bn4d_branch2b - 115
activation_33 - 116
res4d_branch2c - 117
bn4d_branch2c - 118
add_11 - 119
activation_34 - 120
res4e_branch2a - 121
bn4e_branch2a - 122
activation_35 - 123
res4e_branch2b - 124
bn4e_branch2b - 125
activation_36 - 126
res4e_branch2c - 127
bn4e_branch2c - 128
add_12 - 129
activation_37 - 130
res4f_branch2a - 131
bn4f_branch2a - 132
activation_38 - 133
res4f_branch2b - 134
bn4f_branch2b - 135
activation_39 - 136
res4f_branch2c - 137
bn4f_branch2c - 138
add_13 - 139
activation_40 - 140
res5a_branch2a - 141
bn5a_branch2a - 142
activation_41 - 143
res5a_branch2b - 144
bn5a_branch2b - 145
activation_42 - 146
res5a_branch2c - 147
res5a_branch1 - 148
bn5a_branch2c - 149
bn5a_branch1 - 150
add_14 - 151
activation_43 - 152
res5b_branch2a - 153
bn5b_branch2a - 154
activation_44 - 155
res5b_branch2b - 156
bn5b_branch2b - 157
activation_45 - 158
res5b_branch2c - 159
bn5b_branch2c - 160
add_15 - 161
activation_46 - 162
res5c_branch2a - 163
bn5c_branch2a - 164
activation_47 - 165
res5c_branch2b - 166
bn5c_branch2b - 167
activation_48 - 168
res5c_branch2c - 169
bn5c_branch2c - 170
add_16 - 171
activation_49 - 172
avg_pool - 173
global_average_pooling2d_1 - 174
dropout_1 - 175
dense_1 - 176
In [11]:
import matplotlib.pyplot as plt
import random
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

def show_heatmap_image(model_show, weights_show):
    test_dir = os.path.join(basedir, "test", "test")
    image_files = glob.glob(os.path.join(test_dir, "*"))
    print(len(image_files))

    plt.figure(figsize=(12, 14))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        # Sample a spread of test images (stride 2000; the offset 113 is arbitrary).
        img = cv2.imread(image_files[2000*i+113])
        img = cv2.resize(img, (224, 224))
        x = img.astype(np.float32)  # astype returns a new array; the result must be assigned
        out, predictions = model_show.predict(np.expand_dims(x, axis=0))
        predictions = predictions[0]
        out = out[0]

        max_idx = np.argmax(predictions)
        prediction = predictions[max_idx]

        status = ["safe driving", "texting - right", "phone - right", "texting - left", "phone - left",
                  "operating radio", "drinking", "reaching behind", "hair and makeup", "talking to passenger"]

        plt.title('c%d |%s| %.2f%%' % (max_idx, status[max_idx], prediction*100))

        # cam = (P - 0.5) * (output . w), keeping the map for the predicted class.
        cam = (prediction - 0.5) * np.matmul(out, weights_show)
        cam = cam[:, :, max_idx]
        # Normalize to [0, 1], then shift and rescale to stretch the contrast.
        cam -= cam.min()
        cam /= cam.max()
        cam -= 0.2
        cam /= 0.8

        cam = cv2.resize(cam, (224, 224))
        heatmap = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
        heatmap[np.where(cam <= 0.2)] = 0  # mask out weak activations

        out = cv2.addWeighted(img, 0.8, heatmap, 0.4, 0)

        plt.axis('off')
        plt.imshow(out[:, :, ::-1])  # BGR -> RGB for matplotlib
print("done")
done
In [12]:
weights = model.layers[176].get_weights()[0]  # dense_1 weights, shape (2048, 10)
layer_output = model.layers[172].output       # activation_49: the last 7x7x2048 feature map
model2 = Model(model.input, [layer_output, model.output])
print("layer_output {0}".format(layer_output))
print("weights shape {0}".format(weights.shape))
show_heatmap_image(model2, weights)
layer_output Tensor("activation_49_1/Relu:0", shape=(?, 7, 7, 2048), dtype=float32)
weights shape (2048, 10)
79726
In [13]:
weights = model.layers[176].get_weights()[0]  # dense_1 weights, shape (2048, 10)
layer_output = model.layers[162].output       # activation_46: one residual block earlier
model2 = Model(model.input, [layer_output, model.output])
print("layer_output {0}".format(layer_output))
print("weights shape {0}".format(weights.shape))
show_heatmap_image(model2, weights)
layer_output Tensor("activation_46_1/Relu:0", shape=(?, 7, 7, 2048), dtype=float32)
weights shape (2048, 10)
79726
In [14]:
weights = model.layers[176].get_weights()[0]  # dense_1 weights, shape (2048, 10)
layer_output = model.layers[152].output       # activation_43: two residual blocks earlier
model2 = Model(model.input, [layer_output, model.output])
print("layer_output {0}".format(layer_output))
print("weights shape {0}".format(weights.shape))
show_heatmap_image(model2, weights)
layer_output Tensor("activation_43_1/Relu:0", shape=(?, 7, 7, 2048), dtype=float32)
weights shape (2048, 10)
79726